{
 "cells": [
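  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sequence Labeling\n",
    "\n",
    "This tutorial shows how to use fastNLP for sequence labeling tasks. fastNLP's components let you build a sequence labeling model quickly and conveniently and reach good results. Before reading, you should already be familiar with the basics of fastNLP, especially data loading and model construction; working through this small task should make you more comfortable with the library.\n",
    "\n",
    "## Named Entity Recognition (NER)\n",
    "\n",
    "Named entity recognition extracts entities from text that carry special meaning or are strongly referential, typically person names, place names, organization names, and time expressions. Consider the following example:\n",
    "\n",
    "*我来自复旦大学* (\"I come from Fudan University\")\n",
    "\n",
    "Here \"复旦大学\" (Fudan University) is an organization name. NER must recognize that these four characters form a single unit and that the unit belongs to the organization category. In practice, the problem is converted into a sequence labeling problem.\n",
    "\n",
    "For the sentence \"我来自复旦大学\", the prediction target is [O, O, O, B-ORG, I-ORG, I-ORG, I-ORG]. O stands for outside, meaning the character is not part of any entity; B-ORG marks the beginning of an ORG (short for organization) entity; I-ORG marks a character inside an ORG entity.\n",
    "\n",
    "In this tutorial we use fastNLP to build a model that performs this task.\n",
    "\n",
    "## Loading the Data\n",
    "\n",
    "Data loading in fastNLP is handled by two base classes working together, Loader and Pipe. See the tutorial \"Using Loader and Pipe to Process Data\" (《使用Loader和Pipe处理数据》) for the data loading functions fastNLP provides. Below we use the Weibo NER dataset to demonstrate sequence labeling in fastNLP."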
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n",
      "| raw_chars                         | target                            | chars                             | seq_len |\n",
      "+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n",
      "| ['科', '技', '全', '方', '位',... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... | [792, 1015, 156, 198, 291, 714... | 26      |\n",
      "| ['对', '，', '输', '给', '一',... | [0, 0, 0, 0, 0, 0, 3, 1, 0, 0,... | [123, 2, 1205, 115, 8, 24, 101... | 15      |\n",
      "+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.io import WeiboNERPipe\n",
    "data_bundle = WeiboNERPipe().process_from_file()\n",
    "print(data_bundle.get_dataset('train')[:2])"
   ]
  },
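  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `target` column above stores one integer id per character; each id stands for a BIO-style tag from the `target` vocabulary. As a minimal sketch of how such tags come from entity annotations (plain Python, not a fastNLP API; the helper name `spans_to_bio` and the `(start, end, label)` span format with an exclusive end index are assumptions made for illustration):\n",
    "\n",
    "```python\n",
    "def spans_to_bio(num_chars, spans):\n",
    "    # Convert (start, end, label) spans, end exclusive, into per-character BIO tags.\n",
    "    tags = ['O'] * num_chars\n",
    "    for start, end, label in spans:\n",
    "        tags[start] = 'B-' + label      # first character of the entity\n",
    "        for i in range(start + 1, end):\n",
    "            tags[i] = 'I-' + label      # remaining characters of the entity\n",
    "    return tags\n",
    "\n",
    "\n",
    "sentence = '我来自复旦大学'\n",
    "# Characters 3 to 6 form the organization name.\n",
    "print(spans_to_bio(len(sentence), [(3, 7, 'ORG')]))\n",
    "```"
   ]
  },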
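  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the Model\n",
    "\n",
    "First choose the type of Embedding to use. For details on embeddings, see the tutorial \"Using the Embedding Module to Convert Text into Vectors\" (《使用Embedding模块将文本转成向量》). Here we use Chinese character embeddings pretrained with word2vec."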
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 3321 out of 3471 words in the pre-training embedding.\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding\n",
    "\n",
    "embed = StaticEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name='cn-char-fastnlp-100d')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the Embedding chosen, we can use fastNLP's built-in fastNLP.models.BiLSTMCRF as the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.models import BiLSTMCRF\n",
    "\n",
    "data_bundle.rename_field('chars', 'words')  # BiLSTMCRF.forward() takes an input named 'words', not 'chars', so the field must be renamed\n",
    "model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n",
    "              target_vocab=data_bundle.get_vocab('target'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "Next we choose the metric used to evaluate the model, along with the optimizer and loss used for training."
   ]
  },
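  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SpanFPreRecMetric scores whole entity spans rather than individual tags. The sketch below illustrates the idea in plain Python; it is not fastNLP's implementation, and the helper names `bio_to_spans` and `span_prf` are made up for this example. It assumes well-formed BIO tags and counts a predicted span as correct only if both its boundaries and its label match a gold span.\n",
    "\n",
    "```python\n",
    "def bio_to_spans(tags):\n",
    "    # Collect (start, end, label) spans, end exclusive, from a BIO tag sequence.\n",
    "    spans, start, label = [], None, None\n",
    "    for i, tag in enumerate(list(tags) + ['O']):  # sentinel 'O' closes a trailing span\n",
    "        if start is not None and not (tag.startswith('I-') and tag[2:] == label):\n",
    "            spans.append((start, i, label))\n",
    "            start, label = None, None\n",
    "        if tag.startswith('B-'):\n",
    "            start, label = i, tag[2:]\n",
    "    return spans\n",
    "\n",
    "\n",
    "def span_prf(gold_tags, pred_tags):\n",
    "    # Span-level F1, precision and recall for a single sentence.\n",
    "    gold, pred = set(bio_to_spans(gold_tags)), set(bio_to_spans(pred_tags))\n",
    "    correct = len(gold & pred)\n",
    "    pre = correct / len(pred) if pred else 0.0\n",
    "    rec = correct / len(gold) if gold else 0.0\n",
    "    f = 2 * pre * rec / (pre + rec) if pre + rec else 0.0\n",
    "    return f, pre, rec\n",
    "\n",
    "\n",
    "gold = ['O', 'O', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG']\n",
    "pred = ['O', 'O', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'O']  # entity cut one character short\n",
    "print(span_prf(gold, gold))\n",
    "print(span_prf(gold, pred))\n",
    "```\n",
    "\n",
    "The F value is the harmonic mean of precision and recall, so a span with a wrong boundary, like the one in `pred`, counts against both."
   ]
  },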
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import SpanFPreRecMetric\n",
    "from torch.optim import Adam\n",
    "from fastNLP import LossInForward\n",
    "\n",
    "metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n",
    "optimizer = Adam(model.parameters(), lr=1e-2)\n",
    "loss = LossInForward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train with Trainer. You can change the value of device to select a GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
      "target fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\n",
      "training epochs started 2020-02-27-13-53-24\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=430.0), HTML(value='')), layout=Layout(di…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.89 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 1/10. Step:43/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.067797, pre=0.192771, rec=0.041131\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.9 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 2/10. Step:86/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.344086, pre=0.568047, rec=0.246787\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.88 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 3/10. Step:129/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.446701, pre=0.653465, rec=0.339332\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.81 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 4/10. Step:172/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.479871, pre=0.642241, rec=0.383033\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.91 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 5/10. Step:215/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.486312, pre=0.650862, rec=0.388175\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.87 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 6/10. Step:258/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.86 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 7/10. Step:301/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.430335, pre=0.685393, rec=0.313625\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.82 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 8/10. Step:344/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.477759, pre=0.665138, rec=0.372751\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.81 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 9/10. Step:387/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.500759, pre=0.611111, rec=0.424165\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.8 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 10/10. Step:430/430: \n",
      "\r",
      "SpanFPreRecMetric: f=0.496025, pre=0.65, rec=0.401028\n",
      "\n",
      "\r\n",
      "In Epoch:6/Step:258, got best dev performance:\n",
      "SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n",
      "Reloaded the best model.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'best_eval': {'SpanFPreRecMetric': {'f': 0.541401,\n",
       "   'pre': 0.711297,\n",
       "   'rec': 0.437018}},\n",
       " 'best_epoch': 6,\n",
       " 'best_step': 258,\n",
       " 'seconds': 121.39}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastNLP import Trainer\n",
    "import torch\n",
    "\n",
    "device = 0 if torch.cuda.is_available() else 'cpu'  # GPU index 0 when CUDA is available, otherwise CPU\n",
    "trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n",
    "                    dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing\n",
    "After training finishes, you can use Tester to measure the model's performance on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 1.54 seconds!\n",
      "[tester] \n",
      "SpanFPreRecMetric: f=0.439024, pre=0.685279, rec=0.322967\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'SpanFPreRecMetric': {'f': 0.439024, 'pre': 0.685279, 'rec': 0.322967}}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastNLP import Tester\n",
    "tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n",
    "tester.test()"
   ]
  },
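  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sequence Labeling with a Stronger BERT Model\n",
    "\n",
    "To use BERT in fastNLP, you only need to replace fastNLP.embeddings.StaticEmbedding with fastNLP.embeddings.BertEmbedding (you can change device to select a GPU). The cell below also lowers the learning rate to 2e-5 and uses a batch size of 12."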
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 3384 words out of 3471.\n",
      "input fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
      "target fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\n",
      "training epochs started 2020-02-27-13-58-51\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1130.0), HTML(value='')), layout=Layout(d…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate data in 2.7 seconds!\n",
      "Evaluation on dev at Epoch 1/10. Step:113/1130: \n",
      "SpanFPreRecMetric: f=0.008114, pre=0.019231, rec=0.005141\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate data in 2.49 seconds!\n",
      "Evaluation on dev at Epoch 2/10. Step:226/1130: \n",
      "SpanFPreRecMetric: f=0.467866, pre=0.467866, rec=0.467866\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate data in 2.6 seconds!\n",
      "Evaluation on dev at Epoch 3/10. Step:339/1130: \n",
      "SpanFPreRecMetric: f=0.566879, pre=0.482821, rec=0.686375\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate data in 2.56 seconds!\n",
      "Evaluation on dev at Epoch 4/10. Step:452/1130: \n",
      "SpanFPreRecMetric: f=0.651972, pre=0.59408, rec=0.722365\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 2.69 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 5/10. Step:565/1130: \n",
      "\r",
      "SpanFPreRecMetric: f=0.640909, pre=0.574338, rec=0.724936\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate data in 2.52 seconds!\n",
      "Evaluation on dev at Epoch 6/10. Step:678/1130: \n",
      "SpanFPreRecMetric: f=0.661836, pre=0.624146, rec=0.70437\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate data in 2.67 seconds!\n",
      "Evaluation on dev at Epoch 7/10. Step:791/1130: \n",
      "SpanFPreRecMetric: f=0.683429, pre=0.615226, rec=0.768638\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 2.37 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 8/10. Step:904/1130: \n",
      "\r",
      "SpanFPreRecMetric: f=0.674699, pre=0.634921, rec=0.719794\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate data in 2.42 seconds!\n",
      "Evaluation on dev at Epoch 9/10. Step:1017/1130: \n",
      "SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 2.46 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 10/10. Step:1130/1130: \n",
      "\r",
      "SpanFPreRecMetric: f=0.686845, pre=0.62766, rec=0.758355\n",
      "\n",
      "\r\n",
      "In Epoch:9/Step:1017, got best dev performance:\n",
      "SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n",
      "Reloaded the best model.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 1.96 seconds!\n",
      "[tester] \n",
      "SpanFPreRecMetric: f=0.626561, pre=0.596112, rec=0.660287\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'SpanFPreRecMetric': {'f': 0.626561, 'pre': 0.596112, 'rec': 0.660287}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch.optim import Adam\n",
    "\n",
    "from fastNLP import LossInForward, SpanFPreRecMetric, Tester, Trainer\n",
    "from fastNLP.embeddings import BertEmbedding\n",
    "from fastNLP.io import WeiboNERPipe\n",
    "from fastNLP.models import BiLSTMCRF\n",
    "\n",
    "# Reload the data and rename 'chars' to 'words', the input name BiLSTMCRF.forward() expects.\n",
    "data_bundle = WeiboNERPipe().process_from_file()\n",
    "data_bundle.rename_field('chars', 'words')\n",
    "\n",
    "# Swap StaticEmbedding for BertEmbedding; the model definition is otherwise unchanged.\n",
    "embed = BertEmbedding(vocab=data_bundle.get_vocab('words'), model_dir_or_name='cn')\n",
    "model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n",
    "              target_vocab=data_bundle.get_vocab('target'))\n",
    "\n",
    "metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n",
    "optimizer = Adam(model.parameters(), lr=2e-5)  # much smaller than the 1e-2 used with StaticEmbedding\n",
    "loss = LossInForward()\n",
    "\n",
    "device = 0 if torch.cuda.is_available() else 'cpu'  # change the index to select another GPU\n",
    "trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer, batch_size=12,\n",
    "                    dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n",
    "trainer.train()\n",
    "\n",
    "tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n",
    "tester.test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python Now",
   "language": "python",
   "name": "now"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
